//! Route to handlers based on HTTP methods.

#[cfg(feature = "openapi")]
pub mod openapi;

use std::fmt::{Debug, Formatter};

use crate::error::MethodNotAllowed;
use crate::handler::{BoxRequestHandler, into_box_request_handler};
use crate::request::Request;
use crate::response::Response;
use crate::{Method, RequestHandler};

/// A router that routes requests based on the HTTP method.
///
/// This router allows you to register different handlers for different HTTP
/// methods at the same path. When a request is received, the router will
/// dispatch it to the handler registered for that HTTP method.
///
/// If no handler is registered for the request's method, the router invokes
/// its fallback handler. By default, the fallback produces a
/// [405 Method Not Allowed] error; it can be replaced with
/// [`MethodRouter::fallback`]. Requests whose method has no dedicated setter
/// on this type (for example, custom extension methods) always go to the
/// fallback.
///
/// If no handler is registered for [`HEAD`] requests but one is registered for
/// [`GET`], `HEAD` requests are dispatched to the [`GET`] handler instead of
/// the fallback.
///
/// [405 Method Not Allowed]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/405
/// [`HEAD`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/HEAD
/// [`GET`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/GET
///
/// # Examples
///
/// ```
/// use cot::html::Html;
/// use cot::router::method::{MethodRouter, get};
/// use cot::router::{Route, Router};
/// use cot::test::TestRequestBuilder;
///
/// async fn get_handler() -> Html {
///     Html::new("GET response")
/// }
///
/// async fn post_handler() -> Html {
///     Html::new("POST response")
/// }
///
/// # #[tokio::main]
/// # async fn main() -> cot::Result<()> {
/// let router = Router::with_urls([Route::with_handler(
///     "/",
///     get(get_handler).post(post_handler),
/// )]);
///
/// let request = TestRequestBuilder::get("/").router(router.clone()).build();
/// assert_eq!(
///     router
///         .handle(request)
///         .await?
///         .into_body()
///         .into_bytes()
///         .await?,
///     "GET response"
/// );
///
/// let request = TestRequestBuilder::post("/").router(router.clone()).build();
/// assert_eq!(
///     router
///         .handle(request)
///         .await?
///         .into_body()
///         .into_bytes()
///         .await?,
///     "POST response"
/// );
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[must_use]
pub struct MethodRouter {
    inner: InnerMethodRouter<InnerHandler>,
}

// Generates a builder method on `MethodRouter` that registers a handler for a
// single HTTP method. `$name` is the method name (e.g. `get`) and `$method` is
// the matching `Method` constant (e.g. `GET`).
macro_rules! define_method {
    ($name:ident => $method:ident) => {
        #[doc = concat!("Set a handler for the [`",
                    stringify!($method),
                    "`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
                    stringify!($method),
                    ") HTTP method.")]
        ///
        /// Calling this again for the same method replaces the previously
        /// registered handler.
        ///
        /// # Examples
        ///
        /// ```
        /// use cot::html::Html;
        /// use cot::router::method::MethodRouter;
        ///
        /// async fn test_handler() -> Html {
        ///     Html::new("test")
        /// }
        ///
        /// # #[tokio::main]
        /// # async fn main() -> cot::Result<()> {
        #[doc = concat!(
            "let method_router = MethodRouter::new().",
            stringify!($name),
            "(test_handler);"
        )]
        /// #
        /// # let router = cot::router::Router::with_urls(
        /// #     [cot::router::Route::with_handler("/", method_router)]
        /// # );
        /// #
        #[doc = concat!(
            "# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
            stringify!($method),
            ")"
        )]
        /// #     .router(router.clone())
        /// #     .build();
        /// # assert_eq!(
        /// #     router
        /// #         .handle(request)
        /// #         .await?
        /// #         .into_body()
        /// #         .into_bytes()
        /// #         .await?,
        /// #     "test"
        /// # );
        /// # Ok(())
        /// # }
        /// ```
        pub fn $name<HandlerParams, H>(mut self, handler: H) -> Self
        where
            HandlerParams: 'static,
            H: RequestHandler<HandlerParams> + Send + Sync + 'static,
        {
            // Overwrites any handler previously set for this method.
            self.inner.$name = Some(InnerHandler::new(handler));
            self
        }
    };
}

impl Default for MethodRouter {
    fn default() -> Self {
        Self::new()
    }
}

impl MethodRouter {
    /// Create an empty [`MethodRouter`] with no per-method handlers registered.
    ///
    /// Until handlers are added, every request is passed to the default
    /// fallback, which produces a 405 Method Not Allowed error. For a router
    /// that starts out with a handler already registered, see the shorthand
    /// constructors in [`cot::router::method`], such as [`get`] or [`post`].
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::html::Html;
    /// use cot::router::method::MethodRouter;
    /// use cot::router::{Route, Router};
    /// use cot::test::TestRequestBuilder;
    ///
    /// async fn test_handler() -> Html {
    ///     Html::new("GET response")
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let method_router = MethodRouter::new().get(test_handler);
    ///
    /// let router = Router::with_urls([Route::with_handler("/", method_router)]);
    ///
    /// let request = TestRequestBuilder::get("/").router(router.clone()).build();
    /// assert_eq!(
    ///     router
    ///         .handle(request)
    ///         .await?
    ///         .into_body()
    ///         .into_bytes()
    ///         .await?,
    ///     "GET response"
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn new() -> Self {
        let inner = InnerMethodRouter::new();
        Self { inner }
    }

    define_method!(get => GET);
    define_method!(head => HEAD);
    define_method!(delete => DELETE);
    define_method!(options => OPTIONS);
    define_method!(patch => PATCH);
    define_method!(post => POST);
    define_method!(put => PUT);
    define_method!(trace => TRACE);
    define_method!(connect => CONNECT);

    /// Replace the handler invoked when no method-specific handler matches.
    ///
    /// Without this call, unmatched requests produce a 405 Method Not Allowed
    /// error. A `HEAD` request is still dispatched to the `GET` handler, when
    /// one is registered, before the fallback is considered.
    ///
    /// # Examples
    ///
    /// ```
    /// use cot::StatusCode;
    /// use cot::html::Html;
    /// use cot::response::IntoResponse;
    /// use cot::router::method::MethodRouter;
    /// use cot::router::{Route, Router};
    /// use cot::test::TestRequestBuilder;
    ///
    /// async fn fallback_handler() -> impl IntoResponse {
    ///     Html::new("Method Not Allowed").with_status(StatusCode::METHOD_NOT_ALLOWED)
    /// }
    ///
    /// # #[tokio::main]
    /// # async fn main() -> cot::Result<()> {
    /// let method_router = MethodRouter::new().fallback(fallback_handler);
    ///
    /// let router = Router::with_urls([Route::with_handler("/", method_router)]);
    ///
    /// let request = TestRequestBuilder::get("/").router(router.clone()).build();
    /// assert_eq!(
    ///     router
    ///         .handle(request)
    ///         .await?
    ///         .into_body()
    ///         .into_bytes()
    ///         .await?,
    ///     "Method Not Allowed"
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn fallback<HandlerParams, H>(self, handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let Self { mut inner } = self;
        inner.fallback = InnerHandler::new(handler);
        Self { inner }
    }
}

impl RequestHandler for MethodRouter {
    /// Forward the request to the underlying per-method dispatcher.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self { inner } = self;
        inner.handle(request)
    }
}

/// Per-method handler table backing [`MethodRouter`].
///
/// `T` is the handler type stored for every method except `CONNECT`, which
/// always stores an [`InnerHandler`] (see the comment on that field). Fields
/// are private; they remain accessible from child modules of this module.
#[derive(Debug)]
#[must_use]
struct InnerMethodRouter<T> {
    get: Option<T>,
    head: Option<T>,
    delete: Option<T>,
    options: Option<T>,
    patch: Option<T>,
    post: Option<T>,
    put: Option<T>,
    trace: Option<T>,
    // CONNECT can't be used in OpenAPI, so it's always a base handler
    connect: Option<InnerHandler>,
    // Invoked when no method-specific handler matches the request.
    fallback: InnerHandler,
}

impl<T> InnerMethodRouter<T> {
    /// Create a table with no per-method handlers whose fallback is
    /// [`default_fallback`], which yields a `MethodNotAllowed` error.
    pub(crate) fn new() -> Self {
        let fallback = InnerHandler::new(default_fallback);
        Self {
            fallback,
            get: None,
            head: None,
            delete: None,
            options: None,
            patch: None,
            post: None,
            put: None,
            trace: None,
            connect: None,
        }
    }
}

impl<T: RequestHandler + Send + Sync> RequestHandler for InnerMethodRouter<T> {
    /// Dispatch `request` to the handler registered for its HTTP method.
    ///
    /// A `HEAD` request with no dedicated handler is served by the `GET`
    /// handler when one exists. Any method without a matching handler,
    /// including extension methods outside the nine handled here, is passed
    /// to the fallback handler.
    async fn handle(&self, request: Request) -> cot::Result<Response> {
        // `connect` stores an `InnerHandler` rather than a `T`, so it cannot
        // share the typed lookup below and is dispatched on its own.
        if request.method() == Method::CONNECT {
            return match &self.connect {
                Some(connect_handler) => connect_handler.handle(request).await,
                None => self.fallback.handle(request).await,
            };
        }

        let selected = match *request.method() {
            Method::GET => self.get.as_ref(),
            // Without an explicit HEAD handler, reuse the GET handler.
            Method::HEAD => self.head.as_ref().or(self.get.as_ref()),
            Method::DELETE => self.delete.as_ref(),
            Method::OPTIONS => self.options.as_ref(),
            Method::PATCH => self.patch.as_ref(),
            Method::POST => self.post.as_ref(),
            Method::PUT => self.put.as_ref(),
            Method::TRACE => self.trace.as_ref(),
            _ => None,
        };

        match selected {
            Some(method_handler) => method_handler.handle(request).await,
            None => self.fallback.handle(request).await,
        }
    }
}

/// Type-erased request handler stored in [`InnerMethodRouter`].
///
/// Wraps any [`RequestHandler`] behind a boxed [`BoxRequestHandler`] trait
/// object, so handlers with different concrete types and parameter lists can
/// be stored side by side.
struct InnerHandler(Box<dyn BoxRequestHandler + Send + Sync>);

impl InnerHandler {
    /// Box `handler` behind the type-erased [`BoxRequestHandler`] interface.
    fn new<HandlerParams, H>(handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let erased = into_box_request_handler(handler);
        Self(Box::new(erased))
    }
}

impl Debug for InnerHandler {
    // The boxed handler has no meaningful `Debug` output, so only the type
    // name is printed, rendered as `InnerHandler(..)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("InnerHandler");
        tuple.finish_non_exhaustive()
    }
}

impl RequestHandler for InnerHandler {
    /// Delegate to the wrapped, type-erased handler.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        self.0.handle(request)
    }
}

// Generates a free-function constructor (e.g. `get(handler)`) that builds a
// `MethodRouter` with a single handler registered for one HTTP method.
// `$name` is the function/method name and `$method` the `Method` constant.
macro_rules! define_method_router {
    ($name:ident => $method:ident) => {
        #[doc = concat!(
            "Create a new [`MethodRouter`] with a [`",
            stringify!($method),
            "`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
            stringify!($method),
            ") handler."
        )]
        ///
        #[doc = concat!(
            "This is a shorthand for calling [`MethodRouter::new`] and then [`MethodRouter::",
            stringify!($name),
            "`]."
        )]
        ///
        /// # Examples
        ///
        /// ```
        /// use cot::html::Html;
        #[doc = concat!("use cot::router::method::", stringify!($name), ";")]
        ///
        /// async fn test_handler() -> cot::Result<Html> {
        ///     Ok(Html::new("test"))
        /// }
        ///
        /// # #[tokio::main]
        /// # async fn main() -> cot::Result<()> {
        #[doc = concat!("let method_router = ", stringify!($name), "(test_handler);")]
        /// #
        /// # let router = cot::router::Router::with_urls(
        /// #     [cot::router::Route::with_handler("/", method_router)]
        /// # );
        /// #
        #[doc = concat!(
            "# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
            stringify!($method),
            ")"
        )]
        /// #     .router(router.clone())
        /// #     .build();
        /// # assert_eq!(
        /// #     router
        /// #         .handle(request)
        /// #         .await?
        /// #         .into_body()
        /// #         .into_bytes()
        /// #         .await?,
        /// #     "test"
        /// # );
        /// # Ok(())
        /// # }
        /// ```
        pub fn $name<HandlerParams, H>(handler: H) -> MethodRouter
        where
            HandlerParams: 'static,
            H: RequestHandler<HandlerParams> + Send + Sync + 'static,
        {
            MethodRouter::new().$name(handler)
        }
    };
}

// One shorthand constructor per HTTP method that has a dedicated
// `MethodRouter` setter (listed alphabetically).
define_method_router!(connect => CONNECT);
define_method_router!(delete => DELETE);
define_method_router!(get => GET);
define_method_router!(head => HEAD);
define_method_router!(options => OPTIONS);
define_method_router!(patch => PATCH);
define_method_router!(post => POST);
define_method_router!(put => PUT);
define_method_router!(trace => TRACE);

/// Fallback installed by [`InnerMethodRouter::new`]: converts the request's
/// unmatched method into a [`MethodNotAllowed`] error.
async fn default_fallback(method: Method) -> crate::Error {
    let not_allowed = MethodNotAllowed::new(method);
    not_allowed.into()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::StatusCode;
    use crate::html::Html;
    use crate::test::TestRequestBuilder;

    /// Echoes the request method back as the response body, so tests can tell
    /// which method actually reached the handler.
    async fn test_handler(method: Method) -> Html {
        Html::new(method.as_str())
    }

    #[test]
    fn inner_handler_debug() {
        let handler = InnerHandler::new(test_handler);

        let debug_str = format!("{handler:?}");

        assert_eq!(debug_str, "InnerHandler(..)");
    }

    #[cot::test]
    async fn method_router_fallback() {
        let router = MethodRouter::new();

        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();

        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
    }

    #[cot::test]
    async fn method_router_default_fallback() {
        let router = MethodRouter::default();

        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();

        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
    }

    #[cot::test]
    async fn method_router_custom_fallback() {
        let router = MethodRouter::new().fallback(test_handler);

        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.into_body().into_bytes().await.unwrap(), "GET");
    }

    #[cot::test]
    async fn method_router_custom_fallback_unregistered_method() {
        // a registered GET handler must not swallow other methods; they go to
        // the custom fallback instead
        let router = get(test_handler).fallback(test_handler);

        let request = TestRequestBuilder::post("/").build();
        let response = router.handle(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.into_body().into_bytes().await.unwrap(), "POST");
    }

    #[cot::test]
    async fn method_router_extension_method_uses_fallback() {
        // methods without a dedicated setter are always routed to the fallback
        let router = get(test_handler);

        let method = Method::from_bytes(b"PURGE").unwrap();
        let request = TestRequestBuilder::with_method("/", method).build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();

        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
    }

    #[cot::test]
    async fn method_router_get() {
        let router = get(test_handler);

        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.into_body().into_bytes().await.unwrap(), "GET");

        // check other methods
        let methods = [
            Method::DELETE,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
            Method::TRACE,
            Method::CONNECT,
        ];
        for method in methods {
            let request = TestRequestBuilder::with_method("/", method).build();
            let response = router.handle(request).await.unwrap_err();
            let inner = response.inner();

            assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
            assert!(inner.is::<MethodNotAllowed>());
        }
    }

    macro_rules! test_method_router {
        ($test_name:ident, $constructor_name:ident, $method_name:ident) => {
            #[cot::test]
            async fn $test_name() {
                let router = $constructor_name(test_handler);

                let request = TestRequestBuilder::with_method("/", Method::$method_name).build();
                let response = router.handle(request).await.unwrap();

                assert_eq!(response.status(), StatusCode::OK);
                // `test_handler` echoes the method, proving the request
                // reached the handler with its method intact
                assert_eq!(
                    response.into_body().into_bytes().await.unwrap(),
                    stringify!($method_name)
                );
            }
        };
    }

    test_method_router!(method_router_head, head, HEAD);
    test_method_router!(method_router_delete, delete, DELETE);
    test_method_router!(method_router_options, options, OPTIONS);
    test_method_router!(method_router_patch, patch, PATCH);
    test_method_router!(method_router_post, post, POST);
    test_method_router!(method_router_put, put, PUT);
    test_method_router!(method_router_trace, trace, TRACE);
    test_method_router!(method_router_connect, connect, CONNECT);

    #[cot::test]
    async fn method_router_default_head() {
        // verify that the default method router doesn't handle HEAD
        let router = MethodRouter::new();

        let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();

        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());

        // check that if GET handler is defined, HEAD is routed to it
        let router = get(test_handler);

        let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
        let response = router.handle(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // the GET handler sees the original HEAD method
        assert_eq!(response.into_body().into_bytes().await.unwrap(), "HEAD");
    }

    #[cot::test]
    async fn method_router_multiple() {
        let router = MethodRouter::new()
            .get(test_handler)
            .head(test_handler)
            .delete(test_handler)
            .options(test_handler)
            .patch(test_handler)
            .post(test_handler)
            .put(test_handler)
            .trace(test_handler)
            .connect(test_handler);

        for (method, expected_string) in [
            (Method::GET, "GET"),
            (Method::HEAD, "HEAD"),
            (Method::DELETE, "DELETE"),
            (Method::OPTIONS, "OPTIONS"),
            (Method::PATCH, "PATCH"),
            (Method::POST, "POST"),
            (Method::PUT, "PUT"),
            (Method::TRACE, "TRACE"),
            (Method::CONNECT, "CONNECT"),
        ] {
            let request = TestRequestBuilder::with_method("/", method).build();
            let response = router.handle(request).await.unwrap();

            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.into_body().into_bytes().await.unwrap(),
                expected_string
            );
        }
    }
}
